{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: progress in /usr/local/lib/python3.7/site-packages (1.6)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "# %pip installs into the environment of the running kernel; the version is pinned\n",
    "# to the one reported in this cell's recorded output so re-runs get the same package.\n",
    "%pip install progress==1.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "import argparse, os, shutil, time, random, math\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.parallel\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data as data\n",
    "import torchvision.transforms as transforms\n",
    "import torch.nn.functional as F\n",
    "import losses\n",
    "from datasets.cifar100 import get_cifar100\n",
    "from train.train import get_train_fn, get_update_score_fn\n",
    "from train.validate import get_valid_fn\n",
    "from models.net import get_model\n",
    "from losses.loss import get_loss, get_optimizer\n",
    "from utils.config import parse_args, reproducibility, dataset_argument\n",
    "from utils.common import make_imb_data, save_checkpoint, adjust_learning_rate\n",
    "\n",
    "from tqdm.autonotebook import tqdm\n",
    "from utils.logger import logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "---> ---cifar100---\n",
      "---> Argument\n",
      "    └> network     : resnet32\n",
      "    └> epochs      : 200\n",
      "    └> batch_size  : 128\n",
      "    └> update_epoch: 1\n",
      "    └> lr          : 0.1\n",
      "    └> momentum    : 0.9\n",
      "    └> wd          : 0.0002\n",
      "    └> nesterov    : False\n",
      "    └> scheduler   : warmup\n",
      "    └> warmup      : 5\n",
      "    └> aug_prob    : 0.5\n",
      "    └> cutout      : False\n",
      "    └> cmo         : False\n",
      "    └> posthoc_la  : False\n",
      "    └> cuda        : False\n",
      "    └> randaug     : False\n",
      "    └> num_test    : 8\n",
      "    └> verbose     : False\n",
      "    └> out         : ./results/cifar100/ce@N_500_ir_100/\n",
      "    └> data_dir    : /input/dataset/\n",
      "    └> workers     : 4\n",
      "    └> seed        : None\n",
      "    └> gpu         : None\n",
      "    └> dataset     : cifar100\n",
      "    └> num_max     : 500\n",
      "    └> imb_ratio   : 100\n",
      "    └> loss_fn     : ce\n",
      "    └> num_experts : 3\n",
      "    └> num_class   : 100\n"
     ]
    }
   ],
   "source": [
    "# Build the run configuration (parse_args is called with run_type='jupyter').\n",
    "args = parse_args(run_type = 'jupyter')\n",
    "# NOTE(review): the recorded output of this cell shows seed = None - confirm how\n",
    "# reproducibility() handles that value; if it skips seeding, runs are not repeatable.\n",
    "reproducibility(args.seed)\n",
    "# dataset_argument() returns the args object that every later cell uses.\n",
    "args = dataset_argument(args)\n",
    "best_acc = 0 # best test accuracy (train_fe / train_fc keep their own local best_acc)\n",
    "\n",
    "# The training functions below report through args.logger(message, level=...).\n",
    "args.logger = logger(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manual overrides of the parsed arguments for this experiment.\n",
    "# NOTE(review): these run after dataset_argument(args) and logger(args) in the previous\n",
    "# cell, so anything those calls derived from args will not reflect them - confirm.\n",
    "args.imb_ratio = 100\n",
    "args.aug_type = 'many'  # not read anywhere in this notebook; presumably used by project code - verify\n",
    "args.data_dir = '/input/dataset/'\n",
    "args.aug_prob = 0.5\n",
    "# Alternative dataset location (kept for reference):\n",
    "# args.data_dir = '/home/work/cuda/dataset'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[500, 477, 455, 434, 415, 396, 378, 361, 344, 328, 314, 299, 286, 273, 260, 248, 237, 226, 216, 206, 197, 188, 179, 171, 163, 156, 149, 142, 135, 129, 123, 118, 112, 107, 102, 98, 93, 89, 85, 81, 77, 74, 70, 67, 64, 61, 58, 56, 53, 51, 48, 46, 44, 42, 40, 38, 36, 35, 33, 32, 30, 29, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 15, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5]\n",
      "==> Preparing imbalanced CIFAR-100\n",
      "Files already downloaded and verified\n",
      "Magnitude set = tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
      "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], dtype=torch.int32)\n",
      "Operation set = tensor([ 0,  1,  2,  3,  3,  4,  5,  6,  6,  7,  8,  9,  9, 10, 11, 11, 12, 13,\n",
      "        14, 14, 15, 16, 17, 17, 18, 19, 20, 20, 21, 22, 22], dtype=torch.int32)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "#Train: 10847, #Test: 9000\n"
     ]
    }
   ],
   "source": [
    "# Build the imbalanced training set and the data loaders used by train_fe / train_fc.\n",
    "\n",
    "# Cap the largest class at an equal share of 50000 images\n",
    "# (presumably the CIFAR training-set size - verify if another dataset is added).\n",
    "if args.num_max > 50000 / args.num_class:\n",
    "    args.num_max = int(50000 / args.num_class)\n",
    "\n",
    "N_SAMPLES_PER_CLASS = make_imb_data(args.num_max, args.num_class, args.imb_ratio)\n",
    "\n",
    "if args.dataset == 'cifar100':\n",
    "    print('==> Preparing imbalanced CIFAR-100')\n",
    "    trainset, allset, devset, testset = get_cifar100(\n",
    "        os.path.join(args.data_dir, 'cifar100/'),\n",
    "        imb_ratio=args.imb_ratio,\n",
    "        cutout=args.cutout,\n",
    "        contrast=args.loss_fn == 'ncl',\n",
    "        randaug=args.randaug,\n",
    "        aug_prob=args.aug_prob,\n",
    "    )\n",
    "    # The per-class counts of the dataset that was actually built replace the estimate above.\n",
    "    N_SAMPLES_PER_CLASS = trainset.img_num_list\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError on `trainset` below.\n",
    "    raise ValueError(f'Unsupported dataset: {args.dataset!r} (this notebook only handles cifar100)')\n",
    "\n",
    "trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,\n",
    "                              num_workers=args.workers, drop_last=False)\n",
    "allloader = data.DataLoader(allset, batch_size=args.batch_size, shuffle=False,\n",
    "                            num_workers=args.workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
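    "### Fixed feature extractor check\n",
    "\n",
    "Train the whole network with `train_fe`, then retrain only the classifier head (`model.linear`) with `train_fc`."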
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_fe(many_aug, med_aug, few_aug, alternate=False):\n",
    "    \"\"\"Train the whole network (feature extractor and classifier) with the 'ce' loss.\n",
    "\n",
    "    Classes are grouped by their training-set size: many (> 100 samples),\n",
    "    med (20 to 100) and few (< 20). Every training sample is given the value of\n",
    "    its class group in a per-sample state tensor, which is assigned to\n",
    "    trainloader.dataset.curr_state before each epoch.\n",
    "\n",
    "    Args:\n",
    "        many_aug: state value for samples of many-shot classes.\n",
    "        med_aug: state value for samples of medium-shot classes.\n",
    "        few_aug: state value for samples of few-shot classes.\n",
    "        alternate: if True, even epochs use an all-zero state and odd epochs the\n",
    "            per-group state; if False, the per-group state is used in every epoch.\n",
    "\n",
    "    Returns:\n",
    "        The model as it is after the last epoch (no best-epoch weights are restored).\n",
    "\n",
    "    Side effects:\n",
    "        Sets args.loss_fn to 'ce' and rebinds trainloader.dataset.curr_state.\n",
    "    \"\"\"\n",
    "    per_class_num = torch.tensor(N_SAMPLES_PER_CLASS)\n",
    "    # Sets give constant-time membership tests in the per-sample loop below.\n",
    "    many = set(torch.where(per_class_num > 100)[0].tolist())\n",
    "    med = set(torch.where((per_class_num <= 100) & (per_class_num >= 20))[0].tolist())\n",
    "    few = set(torch.where(per_class_num < 20)[0].tolist())\n",
    "\n",
    "    dataset = trainloader.dataset\n",
    "    # How the dataset interprets curr_state is defined in project code\n",
    "    # (presumably a per-sample augmentation strength - verify).\n",
    "    aug_state = torch.zeros(len(dataset))\n",
    "    for idx, label in enumerate(dataset.targets):\n",
    "        label = int(label)\n",
    "        if label in many:\n",
    "            aug_state[idx] = many_aug\n",
    "        elif label in med:\n",
    "            aug_state[idx] = med_aug\n",
    "        elif label in few:\n",
    "            aug_state[idx] = few_aug\n",
    "        else:\n",
    "            raise NotImplementedError(f'label {label} belongs to no class group')\n",
    "    orig_state = torch.zeros(len(dataset))\n",
    "\n",
    "    # Set the loss name before the model is built so that a value left behind by an\n",
    "    # earlier train_fc() call ('bs') cannot leak into get_model(), in case it reads it.\n",
    "    args.loss_fn = 'ce'\n",
    "    print('==> creating {}'.format(args.network))\n",
    "    model = get_model(args, N_SAMPLES_PER_CLASS)\n",
    "    train_criterion = get_loss(args, N_SAMPLES_PER_CLASS)\n",
    "    criterion = nn.CrossEntropyLoss(reduction='mean')  # for validation\n",
    "    optimizer = get_optimizer(args, model)\n",
    "    train = get_train_fn(args)\n",
    "    validate = get_valid_fn(args)\n",
    "\n",
    "    best_acc, best_cls = 0, None\n",
    "    for epoch in tqdm(range(args.epochs)):\n",
    "        adjust_learning_rate(optimizer, epoch, None, args)\n",
    "        # Without this assignment in the non-alternating case the three *_aug\n",
    "        # arguments would have no effect on training.\n",
    "        if alternate and epoch % 2 == 0:\n",
    "            dataset.curr_state = orig_state\n",
    "        else:\n",
    "            dataset.curr_state = aug_state\n",
    "        train(args, trainloader, model, optimizer, train_criterion, epoch, None, None)\n",
    "        _, all_acc, all_cls = validate(\n",
    "            args, allloader, model, criterion, N_SAMPLES_PER_CLASS,\n",
    "            num_class=args.num_class, mode='All Valid')\n",
    "\n",
    "        # The first evaluated epoch always counts, so best_cls is bound even if accuracy is 0.\n",
    "        if best_cls is None or best_acc < all_acc:\n",
    "            best_acc, best_cls = all_acc, all_cls\n",
    "\n",
    "    # Print the final results\n",
    "    args.logger('Final Performance...', level=1)\n",
    "    if best_cls is None:\n",
    "        # args.epochs == 0: nothing was evaluated.\n",
    "        args.logger('no epoch was run; nothing to report', level=2)\n",
    "    else:\n",
    "        args.logger(f'best bAcc (all): {best_acc}', level=2)\n",
    "        args.logger(f'best bAcc (many): {best_cls[0]}', level=2)\n",
    "        args.logger(f'best bAcc (med): {best_cls[1]}', level=2)\n",
    "        args.logger(f'best bAcc (few): {best_cls[2]}', level=2)\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def adjust_learning_rate_crt(optimizer, epoch, args):\n",
    "    \"\"\"Step learning-rate schedule for the classifier-retraining stage.\n",
    "\n",
    "    `epoch` is zero-based. Counting epochs from one, args.lr is used for epochs\n",
    "    1-3 and is multiplied by 0.1 for epochs 4-6, by 0.01 for epochs 7-8 and by\n",
    "    0.001 from epoch 9 on. The selected rate is written to every parameter group\n",
    "    of `optimizer` and returned.\n",
    "    \"\"\"\n",
    "    step = epoch + 1\n",
    "    # (one-based epoch after which the factor applies, factor), latest stage first.\n",
    "    decay_table = ((8, 0.001), (6, 0.01), (3, 0.1))\n",
    "    new_lr = args.lr\n",
    "    for boundary, factor in decay_table:\n",
    "        if step > boundary:\n",
    "            new_lr = args.lr * factor\n",
    "            break\n",
    "    for group in optimizer.param_groups:\n",
    "        group['lr'] = new_lr\n",
    "    return new_lr\n",
    "    \n",
    "def train_fc(model, epochs=10):\n",
    "    \"\"\"Retrain only the classifier head (model.linear) with the 'bs' loss.\n",
    "\n",
    "    Args:\n",
    "        model: network to fine-tune; it must expose its classifier as `model.linear`.\n",
    "        epochs: number of retraining epochs. adjust_learning_rate_crt() has fixed\n",
    "            milestones (after epochs 3, 6 and 8), so it is laid out for the default of 10.\n",
    "\n",
    "    Side effects:\n",
    "        Sets args.loss_fn to 'bs', zeroes trainloader.dataset.curr_state in place and\n",
    "        turns off gradients for every parameter of `model` outside `model.linear`.\n",
    "    \"\"\"\n",
    "    args.loss_fn = 'bs'\n",
    "\n",
    "    train_criterion = get_loss(args, N_SAMPLES_PER_CLASS)\n",
    "    criterion = nn.CrossEntropyLoss(reduction='mean')  # for validation\n",
    "    train = get_train_fn(args)\n",
    "    validate = get_valid_fn(args)\n",
    "\n",
    "    # Freeze the feature extractor: the optimizer below never updated it, but its\n",
    "    # gradients were still computed and accumulated on every backward pass.\n",
    "    # NOTE(review): if train() switches the model to train mode, BatchNorm running\n",
    "    # statistics of the backbone can still change - confirm.\n",
    "    head_params = list(model.linear.parameters())\n",
    "    head_ids = {id(p) for p in head_params}\n",
    "    for param in model.parameters():\n",
    "        if id(param) not in head_ids:\n",
    "            param.requires_grad_(False)\n",
    "    optimizer = optim.SGD(head_params, lr=args.lr, momentum=args.momentum, weight_decay=args.wd)\n",
    "\n",
    "    # Reset the per-sample state of the training set to zero, in place.\n",
    "    curr_state = trainloader.dataset.curr_state\n",
    "    for i in range(len(curr_state)):\n",
    "        curr_state[i] = 0\n",
    "\n",
    "    best_acc, best_cls = 0, None\n",
    "    for epoch in tqdm(range(epochs)):\n",
    "        adjust_learning_rate_crt(optimizer, epoch, args)\n",
    "        train(args, trainloader, model, optimizer, train_criterion, epoch, None, None)\n",
    "        _, all_acc, all_cls = validate(\n",
    "            args, allloader, model, criterion, N_SAMPLES_PER_CLASS,\n",
    "            num_class=args.num_class, mode='All Valid')\n",
    "\n",
    "        # The first evaluated epoch always counts, so best_cls is bound even if accuracy is 0.\n",
    "        if best_cls is None or best_acc < all_acc:\n",
    "            best_acc, best_cls = all_acc, all_cls\n",
    "\n",
    "    # Print the final results\n",
    "    args.logger('Final Performance...', level=1)\n",
    "    if best_cls is None:\n",
    "        # epochs == 0: nothing was evaluated.\n",
    "        args.logger('no epoch was run; nothing to report', level=2)\n",
    "    else:\n",
    "        args.logger(f'best bAcc (all): {best_acc}', level=2)\n",
    "        args.logger(f'best bAcc (many): {best_cls[0]}', level=2)\n",
    "        args.logger(f'best bAcc (med): {best_cls[1]}', level=2)\n",
    "        args.logger(f'best bAcc (few): {best_cls[2]}', level=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==> creating resnet32\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3ef30a73fd045fa960e7d0a07e6ae30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Experiment 1 - state values per class group: many=5, med=0, few=0.\n",
    "model1 = train_fe(many_aug=5, med_aug=0, few_aug=0, alternate=False)\n",
    "train_fc(model=model1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment 2 - state values per class group: many=0, med=0, few=5.\n",
    "model2 = train_fe(many_aug=0, med_aug=0, few_aug=5, alternate=False)\n",
    "train_fc(model=model2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment 3 - state values per class group: many=0, med=0, few=0.\n",
    "model3 = train_fe(many_aug=0, med_aug=0, few_aug=0, alternate=False)\n",
    "train_fc(model=model3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
  },
  "kernelspec": {
   "display_name": "Full on Python 3.7 (GPU)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
